package com.xj.hash;

import java.util.*;

/**
 * Author: xiajun
 * Date: 2015-07-25
 * Time: 17:37:00
 * Consistent-hash implementation.
 *
 * <p>Each physical node ("resource") is placed on a hash ring at {@code virtualNodeSize}
 * positions (virtual nodes). A key is routed to the first virtual node whose hash is greater
 * than or equal to the key's hash, wrapping around to the smallest hash on the ring.
 *
 * <p>Thread-safety: the ring is always rebuilt into a fresh map and then published through a
 * {@code volatile} field, and the published map is never modified afterwards, so
 * {@link #getShard(Object)} always works on a complete ring. {@link #addNode} and
 * {@link #delNode} are {@code synchronized} so concurrent mutations do not lose updates.
 *
 * @param <S> type of the physical node
 */
public class Shard<S> {
    /** Default number of virtual nodes per physical node. */
    private static final int DEFAULT_VIRTUAL_NODE_SIZE = 1024;

    /** Number of virtual nodes placed on the ring for each physical node. */
    private final int virtualNodeSize;
    /** Virtual-node hash -> physical node. Replaced wholesale, never mutated after publication. */
    private volatile NavigableMap<Long, S> nodes;
    /** Physical nodes (private copy, guarded by {@code this} after construction). */
    private final List<S> resources;

    /**
     * Creates a ring over the given physical nodes.
     *
     * @param node            physical nodes; the list is copied, so later changes to it do not
     *                        affect this ring (and this ring never modifies it)
     * @param virtualNodeSize number of virtual nodes mapped to each physical node; must be positive
     * @throws IllegalArgumentException if {@code node} is null or empty, or
     *                                  {@code virtualNodeSize} is not positive
     */
    public Shard(List<S> node, int virtualNodeSize) {
        if (node == null || node.isEmpty()) {
            throw new IllegalArgumentException("init shard error: node is null.");
        }
        if (virtualNodeSize <= 0) {
            throw new IllegalArgumentException(
                    "init shard error: virtual node size must be positive, got " + virtualNodeSize);
        }
        this.virtualNodeSize = virtualNodeSize;
        // Defensive copy: storing the caller's list meant addNode() mutated it, and failed
        // outright for fixed-size/immutable lists such as Arrays.asList(...).
        this.resources = new ArrayList<>(node);
        init();
    }

    /**
     * Creates a ring over the given physical nodes with the default of 1024 virtual nodes each.
     *
     * @param node physical nodes; copied
     * @throws IllegalArgumentException if {@code node} is null or empty
     */
    public Shard(List<S> node) {
        this(node, DEFAULT_VIRTUAL_NODE_SIZE);
    }

    /**
     * Rebuilds the virtual-node -> physical-node mapping from {@link #resources} into a new map
     * and publishes it. Called from the constructor or with the instance lock held.
     */
    private void init() {
        NavigableMap<Long, S> ring = new TreeMap<>();
        for (S s : resources) {
            for (int j = 0; j < virtualNodeSize; j++) {
                // NOTE(review): placement depends on s.hashCode(). Two distinct nodes with equal
                // hash codes get identical virtual-node keys (the later one wins), and nodes that
                // rely on Object's identity hashCode may land elsewhere in each JVM run. Kept as-is
                // because changing the key format would remap every existing key.
                long h = MurmurHash.hash64("SHARD-" + s.hashCode() + "-" + j);
                ring.put(h, s);
            }
        }
        nodes = ring;
    }

    /**
     * Returns the physical node responsible for {@code key}.
     *
     * @param key routing key; hashed via its {@code toString()} (for String keys this is the
     *            string itself)
     * @return the physical node owning the first virtual node at or after the key's hash,
     *         wrapping around the ring
     * @throws NullPointerException   if {@code key} is null
     * @throws NoSuchElementException if every physical node has been removed
     */
    public S getShard(Object key) {
        if (key == null) {
            throw new NullPointerException("key is null.");
        }
        // toString() leaves String keys unchanged (same hash as before) and lets other key types
        // work instead of failing with ClassCastException.
        long h = MurmurHash.hash64(key.toString());
        // Read the volatile field once so the lookup and the wrap-around use the same ring.
        NavigableMap<Long, S> ring = nodes;
        if (ring.isEmpty()) {
            throw new NoSuchElementException("no shard node available.");
        }
        Map.Entry<Long, S> entry = ring.ceilingEntry(h);
        if (entry == null) {
            entry = ring.firstEntry(); // past the largest hash: wrap around to the start
        }
        return entry.getValue();
    }

    /**
     * Adds a physical node and rebuilds the ring. A {@code null} node is ignored.
     *
     * @param node physical node to add
     */
    public synchronized void addNode(S node) {
        if (node == null) {
            return;
        }
        resources.add(node);
        init();
    }

    /**
     * Removes every occurrence (by {@code equals}) of a physical node and rebuilds the ring.
     * A {@code null} node, or one not present, is ignored.
     *
     * @param node physical node to remove
     */
    public synchronized void delNode(S node) {
        if (node == null) {
            return;
        }
        // Remove from the physical-node list as well as the ring. Pruning only the ring let the
        // node reappear on the next addNode(), which rebuilds the ring from resources.
        // Rebuilding (rather than editing the published map) keeps concurrent readers safe.
        if (resources.removeAll(Collections.singleton(node))) {
            init();
        }
    }

    /**
     * Demo: routes random keys across four nodes, removing N3 part-way through, and prints the
     * per-node counts, their fractions of the total, and the elapsed milliseconds.
     */
    public static void main(String[] args) {
        List<String> list = new ArrayList<>();
        String map1 = "N1";
        String map2 = "N2";
        String map3 = "N3";
        String map4 = "N4";
        list.add(map1);
        list.add(map2);
        Shard<String> shard = new Shard<>(list);
        shard.addNode(map3);
        shard.addNode(map4);
        Random random = new Random();
        int c1 = 0, c2 = 0, c3 = 0, c4 = 0;
        final int total = 500000;
        long s = System.currentTimeMillis();
        for (int i = 0; i < total; i++) {
            String map = shard.getShard(String.valueOf(random.nextInt(100000)));
            if (map1.equals(map)) {
                c1++;
            } else if (map2.equals(map)) {
                c2++;
            } else if (map3.equals(map)) {
                c3++;
            } else if (map4.equals(map)) {
                c4++;
            }
            if (i == 300000) {
                shard.delNode(map3);
            }
        }
        long e = System.currentTimeMillis();
        System.out.println("c1=" + c1 + "  c2=" + c2 + "  c3=" + c3 + " c4=" + c4);
        float t = total;
        System.out.println("c1=" + c1 / t + "  c2=" + c2 / t + " c3=" + c3 / t + " c4=" + c4 / t);
        System.out.println(e - s);
    }
}
